import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import argparse
import os
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator

from clipper import Clipper
from data import get_eeg_dls
import utils

# Import different model variants
from model_concat import DiffusionEEGModelWithConcat, ImageToEEGModelWithConcat
from model_add import DiffusionEEGModelWithAdd, ImageToEEGModelWithAdd


def parse_args():
    """Define and parse the command-line interface of the ablation trainer.

    Options are registered in a fixed order (data, model, ablation, training,
    miscellaneous) so that ``--help`` lists them in that order.

    Returns:
        argparse.Namespace holding every option below. ``--data_path`` and
        ``--model_type`` are mandatory; all other options have defaults.
    """
    # (flag, keyword arguments for ArgumentParser.add_argument), in registration order.
    option_table = (
        # Data parameters
        ('--data_path', dict(type=str, required=True, help='Dataset path')),
        ('--subject', dict(type=int, default=1, help='Subject number')),
        ('--batch_size', dict(type=int, default=16, help='Training batch size')),
        ('--val_batch_size', dict(type=int, default=32, help='Validation batch size')),
        ('--num_workers', dict(type=int, default=4, help='Data loader process count')),

        # Model parameters
        ('--clip_variant', dict(type=str, default='ViT-L/14',
                                choices=['RN50', 'ViT-L/14', 'ViT-B/32', 'RN50x64'],
                                help='CLIP model version')),
        ('--eeg_channels', dict(type=int, default=63, help='Number of EEG electrodes')),
        ('--eeg_length', dict(type=int, default=250, help='Number of EEG sampling points')),
        ('--hidden_dim', dict(type=int, default=768, help='Hidden layer dimension')),
        ('--num_train_timesteps', dict(type=int, default=1000, help='Diffusion training time steps')),

        # Ablation experiment parameters: which cross-modal fusion variant to train
        ('--model_type', dict(type=str, required=True,
                              choices=['concat', 'add'],
                              help='Model type: concat (concatenation), add (addition)')),

        # Training parameters
        ('--epochs', dict(type=int, default=50, help='Number of training epochs')),
        ('--lr', dict(type=float, default=1e-4, help='Learning rate')),
        ('--weight_decay', dict(type=float, default=1e-5, help='Weight decay')),
        ('--warmup_steps', dict(type=int, default=1000, help='Learning rate warmup steps')),
        ('--gradient_clip', dict(type=float, default=1.0, help='Gradient clipping')),
        ('--loss_type', dict(type=str, default='mse', choices=['mse'],
                             help='Loss function type (mse: MSE Loss)')),

        # Other parameters
        ('--device', dict(type=str, default='cuda', help='Device type')),
        ('--gpu_id', dict(type=int, default=None, help='Specify GPU number (e.g.: 0, 1, 2)')),
        ('--seed', dict(type=int, default=42, help='Random seed')),
        ('--save_dir', dict(type=str, default='./checkpoints_ablation', help='Model save directory')),
        ('--log_interval', dict(type=int, default=100, help='Log output interval')),
        ('--experiment_name', dict(type=str, default='',
                                   help='Experiment name, used to distinguish different experiments')),
    )

    cli = argparse.ArgumentParser(description='Cross-modal fusion ablation experiment training script')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def create_model(args, device):
    """Build the CLIP encoder plus the diffusion model for the chosen fusion variant.

    Args:
        args: Parsed CLI namespace. Reads ``clip_variant``, ``model_type``,
            ``eeg_channels``, ``eeg_length``, ``hidden_dim`` and
            ``num_train_timesteps``.
        device: Device passed to ``Clipper`` and to the diffusion model
            constructor.

    Returns:
        An ``ImageToEEGModelWithConcat`` (``model_type == 'concat'``) or
        ``ImageToEEGModelWithAdd`` (``model_type == 'add'``) that wraps the
        CLIP model and the matching diffusion model.

    Raises:
        ValueError: if ``args.model_type`` is neither ``'concat'`` nor ``'add'``.
            The CLIP model has already been constructed when this is raised.
    """
    # Diagnostic device report; this function does not gate it on a main process.
    if torch.cuda.is_available():
        print(f"Current device in use: {device}")
        print(f"Number of available GPUs: {torch.cuda.device_count()}")
        active_gpu = torch.cuda.current_device()
        print(f"Current GPU: {active_gpu}")
        print(f"GPU name: {torch.cuda.get_device_name(active_gpu)}")

    clip_model = Clipper(clip_variant=args.clip_variant, device=device)

    # model_type -> (banner, diffusion backbone class, image-to-EEG wrapper class)
    variants = {
        'concat': (
            'Creating Concatenation cross-modal fusion model...',
            DiffusionEEGModelWithConcat,
            ImageToEEGModelWithConcat,
        ),
        'add': (
            'Creating Addition cross-modal fusion model...',
            DiffusionEEGModelWithAdd,
            ImageToEEGModelWithAdd,
        ),
    }
    selected = variants.get(args.model_type)
    if selected is None:
        raise ValueError(f'Unsupported model type: {args.model_type}')

    banner, diffusion_cls, wrapper_cls = selected
    print(banner)
    diffusion_model = diffusion_cls(
        eeg_channels=args.eeg_channels,
        eeg_length=args.eeg_length,
        hidden_dim=args.hidden_dim,
        num_train_timesteps=args.num_train_timesteps,
        device=device,
    )
    return wrapper_cls(clip_model, diffusion_model)


def train_epoch(model, dataloader, optimizer, criterion, accelerator, args, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is unpacked as ``(eeg_data, image_data)``. The model is called
    with ``mode='train'`` and must return a dict holding ``'noise_pred'`` and
    ``'noise'``. ``criterion`` is applied to those two entries.

    Args:
        model: Model (prepared by ``accelerator`` in this script).
        dataloader: Training dataloader.
        optimizer: Optimizer, stepped once per batch.
        criterion: Loss callable applied to ``(noise_pred, noise)``.
        accelerator: ``accelerate.Accelerator`` used for backward and gradient
            clipping, and to decide which process shows progress output.
        args: Parsed CLI namespace. Reads ``epochs``, ``gradient_clip``,
            ``log_interval`` and ``loss_type``.
        epoch: Zero-based epoch index, used only for display.

    Returns:
        ``(avg_loss, avg_diffusion_loss)``: the mean per-batch total loss and the
        mean per-batch diffusion loss on this process. They are equal here
        because the total loss is the diffusion loss alone.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    # Display label for the loss column (e.g. 'mse' -> 'MSE'); computed once, not per batch.
    loss_name = args.loss_type.upper()
    log_interval = args.log_interval
    total_losses = []
    diffusion_losses = []

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{args.epochs}', disable=not accelerator.is_local_main_process)

    for batch_idx, (eeg_data, image_data) in enumerate(pbar):
        eeg_data = eeg_data.float()
        image_data = image_data.float()

        # 4-D EEG (e.g. [batch_size, 4, 63, 250]) is reduced to 3-D by utils.average_eeg_trials.
        if eeg_data.dim() == 4:
            eeg_data = utils.average_eeg_trials(eeg_data)
        # Add a singleton axis: [batch_size, 63, 250] -> [batch_size, 1, 63, 250].
        if eeg_data.dim() == 3:
            eeg_data = eeg_data.unsqueeze(1)

        outputs = model(image_data, eeg_data, mode='train')
        diffusion_loss = criterion(outputs['noise_pred'], outputs['noise'])
        total_batch_loss = diffusion_loss  # no auxiliary loss terms in this ablation

        optimizer.zero_grad()
        accelerator.backward(total_batch_loss)
        # Clip only when accelerate reports that gradients are synchronized.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), args.gradient_clip)
        optimizer.step()

        # .item() copies the value to the host and synchronizes; fetch each loss
        # value once and reuse it for accounting, the progress bar and the log line.
        diffusion_value = diffusion_loss.item()
        total_value = diffusion_value if total_batch_loss is diffusion_loss else total_batch_loss.item()
        total_losses.append(total_value)
        diffusion_losses.append(diffusion_value)

        pbar.set_postfix({
            'Loss': f'{total_value:.4f}',
            loss_name: f'{diffusion_value:.4f}'
        })

        # A non-positive log_interval disables periodic logging instead of dividing by zero.
        if log_interval > 0 and batch_idx % log_interval == 0 and accelerator.is_local_main_process:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, '
                  f'Total Loss: {total_value:.4f}, '
                  f'Diffusion Loss ({loss_name}): {diffusion_value:.4f}')

    if not total_losses:
        raise ValueError('Training dataloader yielded no batches; check the dataset split and batch size.')

    avg_loss = sum(total_losses) / len(total_losses)
    avg_diffusion_loss = np.mean(diffusion_losses)

    return avg_loss, avg_diffusion_loss


def validate_epoch(model, dataloader, criterion, accelerator, args):
    """Evaluate EEG generation quality over ``dataloader`` without gradients.

    For each ``(eeg_data, image_data)`` batch, EEG is generated from the images
    alone (``mode='test'``). It is compared with the reference EEG using
    ``criterion`` and ``utils.compute_correlation``.

    Args:
        model: Model whose test-mode output dict contains ``'generated_eeg'``.
        dataloader: Validation dataloader.
        criterion: Reconstruction loss callable.
        accelerator: Used only to decide whether the progress bar is shown.
        args: Unused; kept so the call signature matches ``train_epoch``.

    Returns:
        ``(avg_loss, avg_correlation)``: the mean of per-batch reconstruction
        losses, and ``np.mean`` over every value returned by
        ``utils.compute_correlation``.

    NOTE(review): nothing here gathers results across processes. In a
    multi-process run, each rank presumably reports only its own shard.
    """
    model.eval()
    loss_sum = 0
    correlation_values = []
    progress = tqdm(dataloader, desc='Validating', disable=not accelerator.is_local_main_process)

    with torch.no_grad():
        for eeg, images in progress:
            eeg, images = eeg.float(), images.float()

            # 4-D EEG is reduced to 3-D by utils.average_eeg_trials.
            if eeg.dim() == 4:
                eeg = utils.average_eeg_trials(eeg)
            # Add a singleton axis: [batch_size, 63, 250] -> [batch_size, 1, 63, 250].
            if eeg.dim() == 3:
                eeg = eeg.unsqueeze(1)

            # Generation is conditioned on the images only; the reference EEG is just the target.
            generated = model(images, mode='test')['generated_eeg']

            loss_sum += criterion(generated, eeg).item()
            correlation_values.extend(utils.compute_correlation(generated, eeg).cpu().numpy())

    return loss_sum / len(dataloader), np.mean(correlation_values)


def _save_best_model(accelerator, model, args, model_dir, loss_name, epoch, val_loss, val_correlation):
    """Save the unwrapped model weights and a metadata text file for a new best epoch.

    Meant to be called on the local main process only.
    """
    model_filename = f'best_model_{args.model_type}_{args.loss_type}.pth'
    best_model_path = os.path.join(model_dir, model_filename)
    accelerator.save(accelerator.unwrap_model(model).state_dict(), best_model_path)

    # Human-readable summary stored next to the checkpoint.
    metadata_filename = f'best_model_{args.model_type}_{args.loss_type}_metadata.txt'
    metadata_path = os.path.join(model_dir, metadata_filename)
    with open(metadata_path, 'w', encoding='utf-8') as f:
        f.write('Best Model Information:\n')
        f.write(f'Model type: {args.model_type}\n')
        f.write(f'Training epoch: {epoch + 1}\n')
        f.write(f'Validation loss: {val_loss:.6f}\n')
        f.write(f'Validation correlation: {val_correlation:.6f}\n')
        f.write(f'Loss function type: {loss_name}\n')
        f.write(f'Subject: {args.subject}\n')
        f.write(f'Model file: {model_filename}\n')

    print(f'Saved best model ({args.model_type}, {loss_name}), validation loss: {val_loss:.4f}, corresponding validation correlation: {val_correlation:.4f}')


def _save_final_results(args, model_dir, loss_name, best_val_loss, best_val_correlation,
                        train_losses, val_losses, correlations):
    """Print the end-of-training summary and write the results file and per-epoch history."""
    print('Training completed!')
    print(f'Best validation loss ({loss_name}): {best_val_loss:.4f}')
    print(f'Corresponding validation correlation: {best_val_correlation:.4f}')
    print(f'Model saved in: {model_dir}')

    # Best validation loss and the correlation from that same epoch.
    result_filename = f'results_{args.model_type}_{args.loss_type}.txt'
    result_file = os.path.join(model_dir, result_filename)
    with open(result_file, 'w', encoding='utf-8') as f:
        f.write(f'Model type: {args.model_type}\n')
        f.write(f'Loss function type: {loss_name}\n')
        f.write(f'Best validation loss: {best_val_loss:.6f}\n')
        f.write(f'Corresponding validation correlation: {best_val_correlation:.6f}\n')
        f.write(f'Number of training epochs: {args.epochs}\n')
        f.write(f'Subject: {args.subject}\n')
    print(f'Best validation results saved to: {result_file}')

    # Per-epoch curves for later plotting/analysis.
    history_filename = f'history_{args.model_type}_{args.loss_type}.npz'
    history_file = os.path.join(model_dir, history_filename)
    np.savez(
        history_file,
        train_losses=np.array(train_losses),
        val_losses=np.array(val_losses),
        correlations=np.array(correlations)
    )
    print(f'Training history saved to: {history_file}')


def main():
    """Script entry point: parse CLI args, train the chosen fusion variant and save results.

    Each epoch trains, then validates. Whenever validation loss improves, the
    local main process saves a checkpoint and metadata. Outputs go under
    ``<save_dir>/<model_type>[/<experiment_name>]/subject<N>``.
    """
    args = parse_args()

    # Restrict visible GPUs before the Accelerator is constructed below.
    # This uses the module-level ``os``. A function-local ``import os`` here
    # would make ``os`` local to all of main() and raise UnboundLocalError
    # further down whenever --gpu_id is omitted.
    if args.gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
        print(f"Using GPU: {args.gpu_id}")

    # NOTE(review): --device and --warmup_steps are parsed but not used in this
    # script. The device comes from Accelerator, and no LR scheduler is created.

    accelerator = Accelerator()
    is_main = accelerator.is_local_main_process

    utils.seed_everything(args.seed)

    # Output directory layout: <save_dir>/<model_type>[/<experiment_name>]/subject<N>
    experiment_dir = os.path.join(args.save_dir, args.model_type)
    if args.experiment_name:
        experiment_dir = os.path.join(experiment_dir, args.experiment_name)
    model_dir = os.path.join(experiment_dir, f'subject{args.subject}')

    if is_main:
        os.makedirs(model_dir, exist_ok=True)

    if is_main:
        print('Loading data...')
    train_dl, val_dl = get_eeg_dls(
        subject=args.subject,
        data_path=args.data_path,
        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        seed=args.seed
    )

    if is_main:
        print('Initializing model...')
    model = create_model(args, accelerator.device)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    if args.loss_type == 'mse':
        criterion = nn.MSELoss()
        loss_name = 'MSE'
    else:
        # The CLI only allows 'mse'; this keeps criterion/loss_name always bound.
        raise ValueError(f'Unsupported loss type: {args.loss_type}')

    if is_main:
        print(f'Using loss function: {loss_name}')

    model, optimizer, train_dl, val_dl = accelerator.prepare(
        model, optimizer, train_dl, val_dl
    )

    train_losses = []
    val_losses = []
    correlations = []
    best_val_loss = float('inf')
    best_val_correlation = 0.0  # correlation from the epoch with the best validation loss

    if is_main:
        print('Starting training...')
    for epoch in range(args.epochs):
        train_loss, train_diff_loss = train_epoch(
            model, train_dl, optimizer, criterion, accelerator, args, epoch
        )
        val_loss, val_correlation = validate_epoch(
            model, val_dl, criterion, accelerator, args
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        correlations.append(val_correlation)

        if is_main:
            print(f'Epoch {epoch+1}/{args.epochs}:')
            print(f'  Training loss ({loss_name}): {train_loss:.4f} (Diffusion: {train_diff_loss:.4f})')
            print(f'  Validation loss ({loss_name}): {val_loss:.4f}')
            print(f'  Validation correlation: {val_correlation:.4f}')

        # Model selection is by minimum validation loss; correlation is recorded alongside.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_correlation = val_correlation
            if is_main:
                _save_best_model(accelerator, model, args, model_dir, loss_name,
                                 epoch, best_val_loss, best_val_correlation)

    if is_main:
        _save_final_results(args, model_dir, loss_name, best_val_loss, best_val_correlation,
                            train_losses, val_losses, correlations)


# Run training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()